"""Modified from https://github.com/kijai/ComfyUI-EasyAnimateWrapper/blob/main/nodes.py
"""
import copy
import gc
import inspect
import json
import os

import comfy.model_management as mm
import cv2
import folder_paths
import numpy as np
import torch
from comfy.utils import ProgressBar, load_torch_file
from diffusers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image

from ...videox_fun.data.bucket_sampler import (ASPECT_RATIO_512,
                                               get_closest_ratio)
from ...videox_fun.data.dataset_image_video import process_pose_params
from ...videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8,
                                  AutoTokenizer, CLIPModel,
                                  VaceWanTransformer3DModel, WanT5EncoderModel)
from ...videox_fun.models.cache_utils import get_teacache_coefficients
from ...videox_fun.pipeline import Wan2_2VaceFunPipeline
from ...videox_fun.utils.fp8_optimization import (
    convert_model_weight_to_float8, convert_weight_dtype_wrapper, undo_convert_weight_dtype_wrapper,
    replace_parameters_by_name)
from ...videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from ...videox_fun.utils.utils import (filter_kwargs, get_image_latent,
                                       get_image_to_video_latent,
                                       get_video_to_video_latent, get_autocast_dtype)
from ..comfyui_utils import (script_directory,
                             search_model_in_possible_folders, to_pil)
from ..wan2_1.nodes import get_wan_scheduler

# CPU caches of the original transformer weights, used when `lora_cache` is enabled
transformer_cpu_cache       = {}
transformer_high_cpu_cache  = {}
# Signature of the previously merged LoRAs, used to detect when a re-merge is needed
lora_path_before            = ""
lora_high_path_before       = ""

class LoadVaceWanTransformer3DModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (
                    folder_paths.get_filename_list("diffusion_models"),
                    {"default": "Wan2_1-T2V-1_3B_bf16.safetensors,"},
                ),
                "precision": (["fp16", "bf16"],
                    {"default": "bf16"}
                ),
            },
        }
    RETURN_TYPES = ("TransformerModel", "STRING")
    RETURN_NAMES = ("transformer", "model_name")
    FUNCTION    = "loadmodel"
    CATEGORY    = "CogVideoXFUNWrapper"

    def loadmodel(self, model_name, precision):
        # Init weight_dtype and device
        device          = mm.get_torch_device()
        offload_device  = mm.unet_offload_device()
        weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[precision]

        mm.unload_all_models()
        mm.cleanup_models()
        mm.soft_empty_cache()
        transformer = None

        model_path = folder_paths.get_full_path("diffusion_models", model_name)
        transformer_state_dict = load_torch_file(model_path, safe_load=True)
        
        eps             = 1e-6
        text_len        = 512
        freq_dim        = 256
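        # Infer the remaining architecture hyperparameters from tensor shapes in the checkpoint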
        dim             = transformer_state_dict["patch_embedding.weight"].shape[0]
        hidden_size     = dim
        in_dim          = transformer_state_dict["patch_embedding.weight"].shape[1]
        in_channels     = in_dim
        ffn_dim         = transformer_state_dict["blocks.0.ffn.0.bias"].shape[0]
        vace_layers     = [0, 5, 10, 15, 20, 25, 30, 35]
        vace_in_dim     = 96

        add_ref_conv            = "ref_conv.weight" in transformer_state_dict
        in_dim_ref_conv         = transformer_state_dict["ref_conv.weight"].shape[1] if add_ref_conv else None
        add_control_adapter     = "control_adapter.conv.weight" in transformer_state_dict
        in_dim_control_adapter  = transformer_state_dict["control_adapter.conv.weight"].shape[1] if add_control_adapter else None

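        # Choose the model variant from the hidden dimension: 5120 -> A14B, 3072 -> 5B, otherwise 1.3B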
        if dim == 5120:
            num_heads = 40
            num_layers = 40
            out_dim = 16
            downscale_factor_control_adapter = 8
            model_name_in_pipeline = "wan2.2-vace-fun-a14b"
                
        elif dim == 3072:
            num_heads = 24
            num_layers = 30
            out_dim = 48
            downscale_factor_control_adapter = 16
            model_name_in_pipeline = "wan2.2-vace-fun-5b"
        else: 
            num_heads = 12
            num_layers = 30
            out_dim = 16
            downscale_factor_control_adapter = 8
            model_name_in_pipeline = "wan2.1-vace-1.3b"
        
        model_type = "t2v"

        kwargs = dict(
            dim = dim,
            in_dim = in_dim,
            eps = eps,
            ffn_dim = ffn_dim,
            freq_dim = freq_dim,
            model_type = model_type,
            num_heads = num_heads,
            num_layers = num_layers,
            out_dim = out_dim,
            text_len = text_len,
            in_channels = in_channels,
            hidden_size = hidden_size,
            add_control_adapter = add_control_adapter,
            add_ref_conv = add_ref_conv,
            in_dim_control_adapter = in_dim_control_adapter // downscale_factor_control_adapter // downscale_factor_control_adapter if in_dim_control_adapter is not None else in_dim_control_adapter,
            in_dim_ref_conv = in_dim_ref_conv,
            downscale_factor_control_adapter = downscale_factor_control_adapter,
            vace_layers = vace_layers,
            vace_in_dim = vace_in_dim
        )

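        # Keep only the kwargs accepted by the installed VaceWanTransformer3DModel constructor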
        sig = inspect.signature(VaceWanTransformer3DModel)
        accepted = {k: v for k, v in kwargs.items() if k in sig.parameters}
        transformer = VaceWanTransformer3DModel(**accepted)
        transformer.load_state_dict(transformer_state_dict)
        transformer = transformer.eval().to(device=offload_device, dtype=weight_dtype)
        return (transformer, model_name_in_pipeline)

class CombineWan2_2VaceFunPipeline:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "transformer": ("TransformerModel",),
                "vae": ("VAEModel",),
                "text_encoder": ("TextEncoderModel",),
                "tokenizer": ("Tokenizer",),
                "model_name": ("STRING",),
                "GPU_memory_mode":(
                    ["model_full_load", "model_full_load_and_qfloat8","model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                    {
                        "default": "model_cpu_offload",
                    }
                ),
            },
            "optional":{
                "clip_encoder": ("ClipEncoderModel",),
                "transformer_2": ("TransformerModel",),
            },
        }

    RETURN_TYPES = ("FunModels",)
    RETURN_NAMES = ("funmodels",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoXFUNWrapper"

    def loadmodel(self, model_name, GPU_memory_mode, transformer, vae, text_encoder, tokenizer, clip_encoder=None, transformer_2=None, model_type="Control"):
        weight_dtype    = transformer.dtype if transformer.dtype not in [torch.float32, torch.float8_e4m3fn, torch.float8_e5m2] else get_autocast_dtype()
        device          = mm.get_torch_device()
        offload_device  = mm.unet_offload_device()

        # Get pipeline
        pipeline = Wan2_2VaceFunPipeline(
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            transformer_2=transformer_2,
            scheduler=None,
        )

        pipeline.remove_all_hooks()
        undo_convert_weight_dtype_wrapper(transformer)
        pipeline.to(device=offload_device)
        transformer = transformer.to(weight_dtype)

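        # Apply the requested GPU memory strategy (full load, fp8 weight casting, CPU offload, or sequential offload)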
        if GPU_memory_mode == "sequential_cpu_offload":
            replace_parameters_by_name(transformer, ["modulation",], device=device)
            transformer.freqs = transformer.freqs.to(device=device)
            pipeline.enable_sequential_cpu_offload()
        elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            pipeline.enable_model_cpu_offload()
        elif GPU_memory_mode == "model_cpu_offload":
            pipeline.enable_model_cpu_offload()
        elif GPU_memory_mode == "model_full_load_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            pipeline.to(device=device)
        else:
            pipeline.to(device)

        funmodels = {
            'pipeline': pipeline, 
            'dtype': weight_dtype,
            'model_name': model_name,
            'model_type': model_type,
            'loras': [],
            'strength_model': []
        }
        return (funmodels,)


class LoadWan2_2VaceFunModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    [ 
                        'Wan2.2-VACE-Fun-A14B',
                    ],
                    {
                        "default": 'Wan2.2-VACE-Fun-A14B',
                    }
                ),
                "GPU_memory_mode":(
                    ["model_full_load", "model_full_load_and_qfloat8", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                    {
                        "default": "model_cpu_offload",
                    }
                ),
                "config": (
                    [
                        "wan2.2/wan_civitai_t2v.yaml",
                    ],
                    {
                        "default": "wan2.2/wan_civitai_t2v.yaml",
                    }
                ),
                "precision": (
                    ['fp16', 'bf16'],
                    {
                        "default": 'bf16'
                    }
                ),
            },
        }
    RETURN_TYPES = ("FunModels",)
    RETURN_NAMES = ("funmodels",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoXFUNWrapper"

    def loadmodel(self, GPU_memory_mode, model, precision, config, model_type="Control"):
        # Init weight_dtype and device
        device          = mm.get_torch_device()
        offload_device  = mm.unet_offload_device()
        weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        mm.unload_all_models()
        mm.cleanup_models()
        mm.soft_empty_cache()

        # Init progress bar
        pbar = ProgressBar(5)

        # Load config
        config_path = f"{script_directory}/config/{config}"
        config = OmegaConf.load(config_path)

        # Check whether the requested model exists in one of the possible folders
        possible_folders = ["CogVideoX_Fun", "Fun_Models", "VideoX_Fun"] + \
                [os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models/Diffusion_Transformer")]
        # Resolve the full path to the model directory
        model_name = search_model_in_possible_folders(possible_folders, model)

        Chosen_AutoencoderKL = {
            "AutoencoderKLWan": AutoencoderKLWan,
            "AutoencoderKLWan3_8": AutoencoderKLWan3_8
        }[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')]
        vae = Chosen_AutoencoderKL.from_pretrained(
            os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
        ).to(weight_dtype)
        # Update pbar
        pbar.update(1)

        # Load Sampler
        print("Load Sampler.")
        scheduler = FlowMatchEulerDiscreteScheduler(
            **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
        )
        # Update pbar
        pbar.update(1)
        
        # Get Transformer
        transformer = VaceWanTransformer3DModel.from_pretrained(
            os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=weight_dtype,
        )

        if config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe":
            transformer_2 = VaceWanTransformer3DModel.from_pretrained(
                os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')),
                transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
                low_cpu_mem_usage=True,
                torch_dtype=weight_dtype,
            )
        else:
            transformer_2 = None
        # Update pbar
        pbar.update(1) 

        # Get tokenizer and text_encoder
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        pbar.update(1) 

        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=weight_dtype,
        )
        pbar.update(1) 

        # Get pipeline
        pipeline = Wan2_2VaceFunPipeline(
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            transformer_2=transformer_2,
            scheduler=scheduler,
        )

        pipeline.remove_all_hooks()
        undo_convert_weight_dtype_wrapper(transformer)

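        # Same GPU memory handling as in CombineWan2_2VaceFunPipeline, applied to both transformers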
        if GPU_memory_mode == "sequential_cpu_offload":
            replace_parameters_by_name(transformer, ["modulation",], device=device)
            transformer.freqs = transformer.freqs.to(device=device)
            if transformer_2 is not None:
                replace_parameters_by_name(transformer_2, ["modulation",], device=device)
                transformer_2.freqs = transformer_2.freqs.to(device=device)
            pipeline.enable_sequential_cpu_offload(device=device)
        elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            if transformer_2 is not None:
                convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device)
                convert_weight_dtype_wrapper(transformer_2, weight_dtype)
            pipeline.enable_model_cpu_offload(device=device)
        elif GPU_memory_mode == "model_cpu_offload":
            pipeline.enable_model_cpu_offload(device=device)
        elif GPU_memory_mode == "model_full_load_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            if transformer_2 is not None:
                convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device)
                convert_weight_dtype_wrapper(transformer_2, weight_dtype)
            pipeline.to(device=device)
        else:
            pipeline.to(device=device)

        funmodels = {
            'pipeline': pipeline, 
            'dtype': weight_dtype,
            'model_name': model_name,
            'model_type': model_type,
            'loras': [],
            'strength_model': [],
            'config': config,
        }
        return (funmodels,)


class Wan2_2VaceFunSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "funmodels": (
                    "FunModels", 
                ),
                "prompt": (
                    "STRING_PROMPT", 
                ),
                "negative_prompt": (
                    "STRING_PROMPT", 
                ),
                "video_length": (
                    "INT", {"default": 81, "min": 1, "max": 161, "step": 4}
                ),
                "base_resolution": (
                    [ 
                        512,
                        640,
                        768,
                        896,
                        960,
                        1024,
                    ], {"default": 640}
                ),
                "seed": (
                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
                ),
                "steps": (
                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
                ),
                "cfg": (
                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
                ),
                "scheduler": (
                    ["Flow", "Flow_Unipc", "Flow_DPM++"],
                    {
                        "default": 'Flow'
                    }
                ),
                "shift": (
                    "INT", {"default": 12, "min": 1, "max": 100, "step": 1}
                ),
                "boundary": (
                    "FLOAT", {"default": 0.875, "min": 0.00, "max": 1.00, "step": 0.001}
                ),
                "teacache_threshold": (
                    "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005}
                ),
                "enable_teacache":(
                    [False, True],  {"default": True,}
                ),
                "num_skip_start_steps": (
                    "INT", {"default": 5, "min": 0, "max": 50, "step": 1}
                ),
                "teacache_offload":(
                    [False, True],  {"default": True,}
                ),
                "cfg_skip_ratio":(
                    "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}
                ),
                "padding_in_subject_ref_images":(
                    [False, True],  {"default": True,}
                ),
            },
            "optional": {
                "control_video": ("IMAGE",),
                "start_image": ("IMAGE",),
                "end_image": ("IMAGE",),
                "subject_ref_images": ("IMAGE",),
                "riflex_k": ("RIFLEXT_ARGS",),
            },
        }
    
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES =("images",)
    FUNCTION = "process"
    CATEGORY = "CogVideoXFUNWrapper"

    def process(self, funmodels, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, padding_in_subject_ref_images, control_video=None, start_image=None, end_image=None, subject_ref_images=None, riflex_k=0):
        global transformer_cpu_cache
        global transformer_high_cpu_cache
        global lora_path_before
        global lora_high_path_before

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        mm.soft_empty_cache()
        gc.collect()
        
        # Get Pipeline
        pipeline = funmodels['pipeline']
        model_name = funmodels['model_name']
        weight_dtype = funmodels['dtype']
        model_type = funmodels['model_type']

        # Compute the most suitable height and width: scale the 512-based aspect-ratio buckets to the requested base resolution
        aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        
        if control_video is not None and type(control_video) is str:
            original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
        elif control_video is not None:
            control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
            original_width, original_height = Image.fromarray(control_video[0]).size
        else:
            original_width, original_height = 384 / 512 * base_resolution, 672 / 512 * base_resolution

        if subject_ref_images is not None:
            subject_ref_images = [to_pil(_subject_ref_image) for _subject_ref_image in subject_ref_images]
            original_width, original_height = subject_ref_images[0].size

        if start_image is not None:
            start_image = [to_pil(_start_image) for _start_image in start_image]
            original_width, original_height = start_image[0].size
        
        if end_image is not None:
            end_image = [to_pil(_end_image) for _end_image in end_image]

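        # Pick the closest aspect-ratio bucket and round the target size down to a multiple of 16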
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]

        # Load Sampler
        pipeline.scheduler = get_wan_scheduler(scheduler, shift)
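        # TeaCache caches transformer residuals across denoising steps and reuses them when the estimated change stays below the threshold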
        coefficients = get_teacache_coefficients(model_name) if enable_teacache else None
        if coefficients is not None:
            print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
            pipeline.transformer.enable_teacache(
                coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
            )
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)
        else:
            pipeline.transformer.disable_teacache()
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.disable_teacache()

        if cfg_skip_ratio is not None:
            print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
            pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps)
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)

        generator = torch.Generator(device).manual_seed(seed)

        with torch.no_grad():
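            # Snap video_length to 1 + k * temporal_compression_ratio so it maps cleanly onto latent frames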
            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1

            if riflex_k > 0:
                latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
                pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames)
                if pipeline.transformer_2 is not None:
                    pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames)

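            # Build the inpaint video and mask tensors from the optional start/end images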
            inpaint_video, inpaint_video_mask, _ = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=(height, width))

            if subject_ref_images is not None:
                subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=(height, width), padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images]
                subject_ref_images = torch.cat(subject_ref_images, dim=2)

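            # Prepare the control video tensors (if any) at the target resolution and frame count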
            control_video, control_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width), fps=16, ref_image=None)

            # Apply LoRA weights. When lora_cache is enabled, keep a CPU copy of the original
            # transformer weights so a changed LoRA set can be re-merged without reloading the checkpoint.
            if funmodels.get("lora_cache", False):
                if len(funmodels.get("loras", [])) != 0:
                    # Save the original weights to cpu
                    if len(transformer_cpu_cache) == 0:
                        print('Save transformer state_dict to cpu memory')
                        transformer_state_dict = pipeline.transformer.state_dict()
                        for key in transformer_state_dict:
                            transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu()
                    
                    lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", []))
                    if lora_path_now != lora_path_before:
                        print('Merge Lora with Cache')
                        lora_path_before = copy.deepcopy(lora_path_now)
                        pipeline.transformer.load_state_dict(transformer_cpu_cache)
                        for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                            pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)
                   
                    if pipeline.transformer_2 is not None:
                        # Save the original weights to cpu
                        if len(transformer_high_cpu_cache) == 0:
                            print('Save transformer high state_dict to cpu memory')
                            transformer_high_state_dict = pipeline.transformer_2.state_dict()
                            for key in transformer_high_state_dict:
                                transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu()

                        lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", []))
                        if lora_high_path_now != lora_high_path_before:
                            print('Merge Lora High with Cache')
                            lora_high_path_before = copy.deepcopy(lora_high_path_now)
                            pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                            for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                                pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
            else:
                print('Merge Lora')
                # Restore the original weights and clear the cache when switching from lora_cache=True to lora_cache=False.
                if len(transformer_cpu_cache) != 0:
                    pipeline.transformer.load_state_dict(transformer_cpu_cache)
                    transformer_cpu_cache = {}
                    lora_path_before = ""
                    gc.collect()
                
                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                # Restore the original transformer_2 weights and clear its cache when switching from lora_cache=True to lora_cache=False.
                if pipeline.transformer_2 is not None:
                    if len(transformer_high_cpu_cache) != 0:
                        pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                        transformer_high_cpu_cache = {}
                        lora_high_path_before = ""
                        gc.collect()

                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")

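            # Run the sampling pipeline; `boundary` sets the denoising split between transformer_2 (high noise) and transformer (low noise) when both are loaded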
            sample = pipeline(
                prompt, 
                num_frames = video_length,
                negative_prompt = negative_prompt,
                height      = height,
                width       = width,
                generator   = generator,
                guidance_scale = cfg,
                num_inference_steps = steps,

                video               = inpaint_video,
                mask_video          = inpaint_video_mask,
                control_video       = control_video,
                subject_ref_images  = subject_ref_images,
                boundary            = boundary,
                comfyui_progressbar = True,
            ).videos
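            # Rearrange to the ComfyUI IMAGE layout: (batch * frames, height, width, channels)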
            videos = rearrange(sample, "b c t h w -> (b t) h w c")

            if not funmodels.get("lora_cache", False):
                print('Unmerge Lora')
                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)
                if pipeline.transformer_2 is not None:
                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
        return (videos,)   